import re

import matplotlib.pyplot as plt

# Run name -> path of the training log written by that run.
files = dict(
    mlp_kaiming_res='./output/mlp_kaiming_res/train_log.txt',
    mlp_xavier_res='./output/mlp_xavier_res/train_log.txt',
)

# Parsed loss values for each run, keyed exactly like ``files``.
# A fresh, independent list is created for every run.
loss_data = {}
for run_name in files:
    loss_data[run_name] = []

# Captures the first number that follows a "Loss:" marker: optional whitespace,
# then a signed decimal / scientific-notation value or nan/inf.
# The earlier ``float(line.split('Loss: ')[1])`` raised IndexError on
# "Loss:0.5" (no space after the colon) and ValueError whenever anything
# followed the value on the same line (e.g. "Loss: 0.5, Acc: 0.9").
_LOSS_RE = re.compile(
    r'Loss:\s*'
    r'([-+]?(?:(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?|(?i:nan|inf(?:inity)?)))'
)

# Read every run's log and collect one loss value per line that mentions
# "Loss:", in file order.
for key, file_path in files.items():
    # Only ASCII numeric tokens are extracted, so replacing undecodable bytes
    # cannot change any parsed value; it just avoids crashing on stray bytes.
    with open(file_path, 'r', encoding='utf-8', errors='replace') as file:
        for line_no, line in enumerate(file, start=1):
            if 'Loss:' not in line:
                continue
            match = _LOSS_RE.search(line)
            if match is None:
                # Fail loudly (as before) but say where: silently skipping
                # would shift every later point on the x-axis.
                raise ValueError(
                    f'{file_path}:{line_no}: found "Loss:" but no numeric '
                    f'value after it: {line.rstrip()!r}'
                )
            loss_data[key].append(float(match.group(1)))

# Draw every run's loss curve on a single shared set of axes. The x value of
# each point is the 1-based position of that loss in the run's list.
fig, ax = plt.subplots(figsize=(10, 6))

for run_name, run_losses in loss_data.items():
    x_values = range(1, len(run_losses) + 1)
    ax.plot(x_values, run_losses, label=run_name)

ax.set_title('Training Loss Curves')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True)
plt.show()
